// SPDX-FileCopyrightText: 2011-2019 Disney Enterprises, Inc.
// SPDX-License-Identifier: LicenseRef-Apache-2.0
// SPDX-FileCopyrightText: 2020 L. E. Segovia <amy@amyspark.me>
// SPDX-License-Identifier: GPL-3.0-or-later

#include <cstdint>

#include "ExprConfig.h"
#include "ExprLLVMAll.h"
#include "ExprNode.h"
#include "VarBlock.h"

extern "C" void KSeExprLLVMEvalFPVarRef(KSeExpr::ExprVarRef *seVR, double *result);
extern "C" void KSeExprLLVMEvalStrVarRef(KSeExpr::ExprVarRef *seVR, double *result);
extern "C" void KSeExprLLVMEvalCustomFunction(int *opDataArg, double *fpArg, char **strArg, void **funcdata, const KSeExpr::ExprFuncNode *node);

namespace KSeExpr
{
#if defined(SEEXPR_ENABLE_LLVM)

LLVM_VALUE promoteToDim(LLVM_VALUE val, unsigned dim, llvm::IRBuilder<> &Builder);

class LLVMEvaluator
{
    // TODO: this seems needlessly complex, let's fix it
    // TODO: let the dev code allocate memory?
    // FP is the native function for this expression.
    template<class T> class LLVMEvaluationContext
    {
    private:
        using FunctionPtr = void (*)(T *, char **, uint32_t);
        using FunctionPtrMultiple = void (*)(char **, uint32_t, uint32_t, uint32_t);
        FunctionPtr functionPtr{nullptr};
        FunctionPtrMultiple functionPtrMultiple{nullptr};
        T *resultData{nullptr};

    public:
        LLVMEvaluationContext(const LLVMEvaluationContext &) = delete;
        LLVMEvaluationContext &operator=(const LLVMEvaluationContext &) = delete;
        LLVMEvaluationContext(LLVMEvaluationContext &&) noexcept = default;
        LLVMEvaluationContext& operator=(LLVMEvaluationContext &&) noexcept = default;
        ~LLVMEvaluationContext()
        {
            delete[] resultData;
        }
        LLVMEvaluationContext() = default;
        
        void init(void *fp, void *fpLoop, int dim)
        {
            reset();
            functionPtr = reinterpret_cast<FunctionPtr>(fp);
            functionPtrMultiple = reinterpret_cast<FunctionPtrMultiple>(fpLoop);
            resultData = new T[dim];
        }
        void reset()
        {
            delete[] resultData;
            resultData = nullptr;
            functionPtr = nullptr;
            resultData = nullptr;
        }
        const T *operator()(VarBlock *varBlock)
        {
            assert(functionPtr && resultData);
            functionPtr(resultData, varBlock ? varBlock->data() : nullptr, varBlock ? varBlock->indirectIndex : 0);
            return resultData;
        }
        void operator()(VarBlock *varBlock, size_t outputVarBlockOffset, size_t rangeStart, size_t rangeEnd)
        {
            assert(functionPtr && resultData);
            functionPtrMultiple(varBlock ? varBlock->data() : nullptr, outputVarBlockOffset, rangeStart, rangeEnd);
        }
    };
    std::unique_ptr<LLVMEvaluationContext<double>> _llvmEvalFP;
    std::unique_ptr<LLVMEvaluationContext<char *>> _llvmEvalStr;

    std::unique_ptr<llvm::LLVMContext> _llvmContext;
    std::unique_ptr<llvm::ExecutionEngine> TheExecutionEngine;

public:
    LLVMEvaluator() = default;

    const char *evalStr(VarBlock *varBlock)
    {
        return *(*_llvmEvalStr)(varBlock);
    }
    const double *evalFP(VarBlock *varBlock)
    {
        return (*_llvmEvalFP)(varBlock);
    }

    void evalMultiple(VarBlock *varBlock, uint32_t outputVarBlockOffset, uint32_t rangeStart, uint32_t rangeEnd)
    {
        return (*_llvmEvalFP)(varBlock, outputVarBlockOffset, rangeStart, rangeEnd);
    }

    void debugPrint()
    {
        // TheModule->print(llvm::errs(), nullptr);
    }

    bool prepLLVM(ExprNode *parseTree, const ExprType &desiredReturnType)
    {
        using namespace llvm;
        InitializeNativeTarget();
        InitializeNativeTargetAsmPrinter();
        InitializeNativeTargetAsmParser();

        std::string uniqueName = getUniqueName();

        // create Module
        _llvmContext = std::make_unique<LLVMContext>();

        std::unique_ptr<Module> TheModule(new Module(uniqueName + "_module", *_llvmContext));

        // create all needed types
        Type *i8PtrTy = Type::getInt8PtrTy(*_llvmContext);                 // char *
        PointerType *i8PtrPtrTy = PointerType::getUnqual(i8PtrTy);         // char **
        PointerType *i8PtrPtrPtrTy = PointerType::getUnqual(i8PtrPtrTy);   // char ***
        Type *i32Ty = Type::getInt32Ty(*_llvmContext);                     // int
        Type *i32PtrTy = Type::getInt32PtrTy(*_llvmContext);               // int *
        Type *i64Ty = Type::getInt64Ty(*_llvmContext);                     // int64 *
        Type *doublePtrTy = Type::getDoublePtrTy(*_llvmContext);           // double *
        PointerType *doublePtrPtrTy = PointerType::getUnqual(doublePtrTy); // double **
        Type *voidTy = Type::getVoidTy(*_llvmContext);                     // void

        // create bindings to helper functions for variables and fucntions
        Function *KSeExprLLVMEvalCustomFunctionFunc = nullptr;
        Function *KSeExprLLVMEvalFPVarRefFunc = nullptr;
        Function *KSeExprLLVMEvalStrVarRefFunc = nullptr;
        Function *KSeExprLLVMEvalstrlenFunc = nullptr;
        Function *KSeExprLLVMEvalmallocFunc = nullptr;
        Function *KSeExprLLVMEvalfreeFunc = nullptr;
        Function *KSeExprLLVMEvalmemsetFunc = nullptr;
        Function *KSeExprLLVMEvalstrcatFunc = nullptr;
        Function *KSeExprLLVMEvalstrcmpFunc = nullptr;
        {
            {
                FunctionType *FT = FunctionType::get(voidTy, {i32PtrTy, doublePtrTy, i8PtrPtrTy, i8PtrPtrTy, i64Ty}, false);
                KSeExprLLVMEvalCustomFunctionFunc = Function::Create(FT, GlobalValue::ExternalLinkage, "KSeExprLLVMEvalCustomFunction", TheModule.get());
            }
            {
                FunctionType *FT = FunctionType::get(voidTy, {i8PtrTy, doublePtrTy}, false);
                KSeExprLLVMEvalFPVarRefFunc = Function::Create(FT, GlobalValue::ExternalLinkage, "KSeExprLLVMEvalFPVarRef", TheModule.get());
            }
            {
                FunctionType *FT = FunctionType::get(voidTy, {i8PtrTy, i8PtrPtrTy}, false);
                KSeExprLLVMEvalStrVarRefFunc = Function::Create(FT, GlobalValue::ExternalLinkage, "KSeExprLLVMEvalStrVarRef", TheModule.get());
            }
            {
                FunctionType *FT = FunctionType::get(i32Ty, {i8PtrTy}, false);
                KSeExprLLVMEvalstrlenFunc = Function::Create(FT, Function::ExternalLinkage, "strlen", TheModule.get());
            }
            {
                FunctionType *FT = FunctionType::get(i8PtrTy, {i32Ty}, false);
                KSeExprLLVMEvalmallocFunc = Function::Create(FT, Function::ExternalLinkage, "malloc", TheModule.get());
            }
            {
                FunctionType *FT = FunctionType::get(voidTy, {i8PtrTy}, false);
                KSeExprLLVMEvalfreeFunc = Function::Create(FT, Function::ExternalLinkage, "free", TheModule.get());
            }
            {
                FunctionType *FT = FunctionType::get(voidTy, {i8PtrTy, i32Ty, i32Ty}, false);
                KSeExprLLVMEvalmemsetFunc = Function::Create(FT, Function::ExternalLinkage, "memset", TheModule.get());
            }
            {
                FunctionType *FT = FunctionType::get(i8PtrTy, {i8PtrTy, i8PtrTy}, false);
                KSeExprLLVMEvalstrcatFunc = Function::Create(FT, Function::ExternalLinkage, "strcat", TheModule.get());
            }
            {
                FunctionType *FT = FunctionType::get(i32Ty, {i8PtrTy, i8PtrTy}, false);
                KSeExprLLVMEvalstrcmpFunc = Function::Create(FT, Function::ExternalLinkage, "strcmp", TheModule.get());
            }
        }

        // create function and entry BB
        bool desireFP = desiredReturnType.isFP();
        std::array<Type *, 3> ParamTys = {desireFP ? doublePtrTy : i8PtrPtrTy, doublePtrPtrTy, i32Ty};
        FunctionType *FT = FunctionType::get(voidTy, ParamTys, false);
        Function *F = Function::Create(FT, Function::ExternalLinkage, uniqueName + "_func", TheModule.get());
#if LLVM_VERSION_MAJOR > 4
        F->addAttribute(llvm::AttributeList::FunctionIndex, llvm::Attribute::AlwaysInline);
#else
        F->addAttribute(llvm::AttributeSet::FunctionIndex, llvm::Attribute::AlwaysInline);
#endif
        {
            // label the function with names
            std::array<const char *, 3> names = {"outputPointer", "dataBlock", "indirectIndex"};
            int idx = 0;
            for (auto &arg : F->args())
                arg.setName(names[idx++]);
        }

        auto dimDesired = desiredReturnType.dim();
        auto dimGenerated = parseTree->type().dim();
        {
            BasicBlock *BB = BasicBlock::Create(*_llvmContext, "entry", F);
            IRBuilder<> Builder(BB);

            // codegen
            Value *lastVal = parseTree->codegen(Builder);

            // return values through parameter.
            Value *firstArg = &*F->arg_begin();
            if (desireFP) {
                Value *newLastVal = promoteToDim(lastVal, dimDesired, Builder);
                if (newLastVal->getType()->isVectorTy()) {
                    // Output is vector - copy values (if possible)

                    assert(dimDesired >= 1 && "error. dim of FP is less than 1.");

                    assert(dimGenerated >= 1 && "error. dim of FP is less than 1.");

                    assert(dimGenerated == 1 || dimGenerated >= dimDesired && "error: unable to match between FP of differing dimensions");

                    auto *VT = llvm::cast<llvm::VectorType>(newLastVal->getType());
#if LLVM_VERSION_MAJOR >= 13
                    if (VT && VT->getElementCount().getKnownMinValue() >= dimDesired) {
#else
                    if (VT && VT->getNumElements() >= dimDesired) {
#endif
                        for (unsigned i = 0; i < dimDesired; ++i) {
                            Value *idx = ConstantInt::get(Type::getInt64Ty(*_llvmContext), i);
                            Value *val = Builder.CreateExtractElement(newLastVal, idx);
                            Value *ptr = IN_BOUNDS_GEP(Builder, firstArg, idx);
                            Builder.CreateStore(val, ptr);
                        }
                    } else {
                        for (unsigned i = 0; i < dimDesired; ++i) {
                            Value *idx = ConstantInt::get(Type::getInt64Ty(*_llvmContext), i);
                            Value *original_idx = ConstantInt::get(Type::getInt64Ty(*_llvmContext), 0);
                            Value *val = Builder.CreateExtractElement(newLastVal, original_idx);
                            Value *ptr = IN_BOUNDS_GEP(Builder, firstArg, idx);
                            Builder.CreateStore(val, ptr);
                        }
                    }
                } else {
                    if (dimGenerated > 1) {
                        Value *newLastVal = promoteToDim(lastVal, dimDesired, Builder);
#ifndef NDEBUG
                        auto *VT = llvm::cast<llvm::VectorType>(newLastVal->getType());
#if LLVM_VERSION_MAJOR >= 13
                        assert(VT && VT->getElementCount().getKnownMinValue() >= dimDesired);
#else
                        assert(VT && VT->getNumElements() >= dimDesired);
#endif
#endif
                        for (unsigned i = 0; i < dimDesired; ++i) {
                            Value *idx = ConstantInt::get(Type::getInt64Ty(*_llvmContext), i);
                            Value *val = Builder.CreateExtractElement(newLastVal, idx);
                            Value *ptr = IN_BOUNDS_GEP(Builder, firstArg, idx);
                            Builder.CreateStore(val, ptr);
                        }
                    } else if (dimGenerated == 1) {
                        for (unsigned i = 0; i < dimDesired; ++i) {
                            Value *ptr = Builder.CreateConstInBoundsGEP1_32(nullptr, firstArg, i);
                            Builder.CreateStore(lastVal, ptr);
                        }
                    } else {
                        assert(false && "error. dim of FP is less than 1.");
                    }
                }
            } else {
                Builder.CreateStore(lastVal, firstArg);
            }

            Builder.CreateRetVoid();
        }

        // write a new function
        FunctionType *FTLOOP = FunctionType::get(voidTy, {i8PtrTy, i32Ty, i32Ty, i32Ty}, false);
        Function *FLOOP = Function::Create(FTLOOP, Function::ExternalLinkage, uniqueName + "_loopfunc", TheModule.get());
        {
            // label the function with names
            std::array<const char *, 4> names = {"dataBlock", "outputVarBlockOffset", "rangeStart", "rangeEnd"};
            int idx = 0;
            for (auto &arg : FLOOP->args()) {
                arg.setName(names[idx++]);
            }
        }
        {
            // Local variables
            Value *dimValue = ConstantInt::get(i32Ty, dimDesired);
            Value *oneValue = ConstantInt::get(i32Ty, 1);

            // Basic blocks
            BasicBlock *entryBlock = BasicBlock::Create(*_llvmContext, "entry", FLOOP);
            BasicBlock *loopCmpBlock = BasicBlock::Create(*_llvmContext, "loopCmp", FLOOP);
            BasicBlock *loopRepeatBlock = BasicBlock::Create(*_llvmContext, "loopRepeat", FLOOP);
            BasicBlock *loopIncBlock = BasicBlock::Create(*_llvmContext, "loopInc", FLOOP);
            BasicBlock *loopEndBlock = BasicBlock::Create(*_llvmContext, "loopEnd", FLOOP);
            IRBuilder<> Builder(entryBlock);
            Builder.SetInsertPoint(entryBlock);

            // Get arguments
            Function::arg_iterator argIterator = FLOOP->arg_begin();
            Value *varBlockCharPtrPtrArg = &*argIterator;
            ++argIterator;
            Value *outputVarBlockOffsetArg = &*argIterator;
            ++argIterator;
            Value *rangeStartArg = &*argIterator;
            ++argIterator;
            Value *rangeEndArg = &*argIterator;
            ++argIterator;

            // Allocate Variables
            Value *rangeStartVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue, "rangeStartVar");
            Value *rangeEndVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue, "rangeEndVar");
            Value *indexVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue, "indexVar");
            Value *outputVarBlockOffsetVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue, "outputVarBlockOffsetVar");
            Value *varBlockDoublePtrPtrVar = Builder.CreateAlloca(doublePtrPtrTy, oneValue, "varBlockDoublePtrPtrVar");
            Value *varBlockTPtrPtrVar = Builder.CreateAlloca(desireFP == true ? doublePtrPtrTy : i8PtrPtrPtrTy, oneValue, "varBlockTPtrPtrVar");

            // Copy variables from args
            Builder.CreateStore(Builder.CreatePointerCast(varBlockCharPtrPtrArg, doublePtrPtrTy, "varBlockAsDoublePtrPtr"), varBlockDoublePtrPtrVar);
            Builder.CreateStore(Builder.CreatePointerCast(varBlockCharPtrPtrArg, desireFP ? doublePtrPtrTy : i8PtrPtrPtrTy, "varBlockAsTPtrPtr"), varBlockTPtrPtrVar);
            Builder.CreateStore(rangeStartArg, rangeStartVar);
            Builder.CreateStore(rangeEndArg, rangeEndVar);
            Builder.CreateStore(outputVarBlockOffsetArg, outputVarBlockOffsetVar);

            // Set output pointer
            Value *outputBasePtrPtr = Builder.CreateGEP(nullptr, CREATE_LOAD(Builder, varBlockTPtrPtrVar), outputVarBlockOffsetArg, "outputBasePtrPtr");
            Value *outputBasePtr = CREATE_LOAD_WITH_ID(Builder, outputBasePtrPtr, "outputBasePtr");
            Builder.CreateStore(CREATE_LOAD(Builder, rangeStartVar), indexVar);

            Builder.CreateBr(loopCmpBlock);
            Builder.SetInsertPoint(loopCmpBlock);
            Value *cond = Builder.CreateICmpULT(CREATE_LOAD(Builder, indexVar), CREATE_LOAD(Builder, rangeEndVar));
            Builder.CreateCondBr(cond, loopRepeatBlock, loopEndBlock);

            Builder.SetInsertPoint(loopRepeatBlock);
            Value *myOutputPtr = Builder.CreateGEP(nullptr, outputBasePtr, Builder.CreateMul(dimValue, CREATE_LOAD(Builder, indexVar)));
            Builder.CreateCall(F, {myOutputPtr, CREATE_LOAD(Builder, varBlockDoublePtrPtrVar), CREATE_LOAD(Builder, indexVar)});

            Builder.CreateBr(loopIncBlock);

            Builder.SetInsertPoint(loopIncBlock);
            Builder.CreateStore(Builder.CreateAdd(CREATE_LOAD(Builder, indexVar), oneValue), indexVar);
            Builder.CreateBr(loopCmpBlock);

            Builder.SetInsertPoint(loopEndBlock);
            Builder.CreateRetVoid();
        }

        if (Expression::debugging) {
#ifdef DEBUG
            std::cerr << "Pre verified LLVM byte code " << std::endl;
            TheModule->print(llvm::errs(), nullptr);
#endif
        }

        // TODO: Find out if there is a new way to veirfy
        // if (verifyModule(*TheModule)) {
        //     std::cerr << "Logic error in code generation of LLVM alert developers" << std::endl;
        //     TheModule->print(llvm::errs(), nullptr);
        // }
        Module *altModule = TheModule.get();
        std::string ErrStr;
        TheExecutionEngine.reset(EngineBuilder(std::move(TheModule))
                                     .setErrorStr(&ErrStr)
                                     //     .setUseMCJIT(true)
                                     .setOptLevel(CodeGenOpt::Aggressive)
                                     .create());

        altModule->setDataLayout(TheExecutionEngine->getDataLayout());

        // Add bindings to C linkage helper functions
        TheExecutionEngine->addGlobalMapping(KSeExprLLVMEvalFPVarRefFunc,
            reinterpret_cast<void *>(KSeExprLLVMEvalFPVarRef));
        TheExecutionEngine->addGlobalMapping(KSeExprLLVMEvalStrVarRefFunc,
            reinterpret_cast<void *>(KSeExprLLVMEvalStrVarRef));
        TheExecutionEngine->addGlobalMapping(KSeExprLLVMEvalCustomFunctionFunc,
            reinterpret_cast<void *>(KSeExprLLVMEvalCustomFunction));
        TheExecutionEngine->addGlobalMapping(KSeExprLLVMEvalstrlenFunc,
            reinterpret_cast<void *>(strlen));
        TheExecutionEngine->addGlobalMapping(KSeExprLLVMEvalstrcatFunc,
            reinterpret_cast<void *>(strcat));
        TheExecutionEngine->addGlobalMapping(KSeExprLLVMEvalstrcmpFunc,
            reinterpret_cast<void *>(strcmp));
        TheExecutionEngine->addGlobalMapping(KSeExprLLVMEvalmemsetFunc,
            reinterpret_cast<void *>(memset));
        TheExecutionEngine->addGlobalMapping(KSeExprLLVMEvalmallocFunc,
            reinterpret_cast<void *>(malloc));
        TheExecutionEngine->addGlobalMapping(KSeExprLLVMEvalfreeFunc,
            reinterpret_cast<void *>(free));

        // [verify]
        std::string errorStr;
        llvm::raw_string_ostream raw(errorStr);
        if (llvm::verifyModule(*altModule, &raw)) {
            parseTree->addError(ErrorCode::Unknown, {errorStr});
            return false;
        }

        // Setup optimization
        llvm::PassManagerBuilder builder;
        std::unique_ptr<llvm::legacy::PassManager> pm(new llvm::legacy::PassManager);
        std::unique_ptr<llvm::legacy::FunctionPassManager> fpm(new llvm::legacy::FunctionPassManager(altModule));
        builder.OptLevel = 3;
#if (LLVM_VERSION_MAJOR >= 4)
        builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
#else
        builder.Inliner = llvm::createAlwaysInlinerPass();
#endif
        builder.populateModulePassManager(*pm);
        // fpm->add(new llvm::DataLayoutPass());
        builder.populateFunctionPassManager(*fpm);
        fpm->run(*F);
        fpm->run(*FLOOP);
        pm->run(*altModule);

        // Create the JIT.  This takes ownership of the module.

        if (!TheExecutionEngine) {
            std::cerr << "Could not create ExecutionEngine: " << ErrStr << std::endl;
            exit(1);
        }

        TheExecutionEngine->finalizeObject();
        void *fp = TheExecutionEngine->getPointerToFunction(F);
        void *fpLoop = TheExecutionEngine->getPointerToFunction(FLOOP);
        if (desireFP) {
            _llvmEvalFP = std::make_unique<LLVMEvaluationContext<double>>();
            _llvmEvalFP->init(fp, fpLoop, dimDesired);
        } else {
            _llvmEvalStr = std::make_unique<LLVMEvaluationContext<char *>>();
            _llvmEvalStr->init(fp, fpLoop, dimDesired);
        }

        if (Expression::debugging) {
#ifdef DEBUG
            std::cerr << "Pre verified LLVM byte code " << std::endl;
            altModule->print(llvm::errs(), nullptr);
#endif
        }

        return true;
    }

    std::string getUniqueName() const
    {
        std::ostringstream o;
        o << std::setbase(16) << reinterpret_cast<uintptr_t>(this);
        return ("_" + o.str());
    }
};

#else // no LLVM support
class LLVMEvaluator
{
public:
    static void unsupported()
    {
        assert(false && "LLVM is not enabled in build");
    }
    static const char *evalStr(VarBlock *)
    {
        unsupported();
        return nullptr;
    }
    static const double *evalFP(VarBlock *)
    {
        unsupported();
        return nullptr;
    }
    static bool prepLLVM(ExprNode *, ExprType)
    {
        unsupported();
        return false;
    }
    static void evalMultiple(VarBlock *, int, size_t, size_t)
    {
        unsupported();
    }
    void debugPrint()
    {
    }
};
#endif

} // end namespace KSeExpr
